import snakes.plugins
snakes.plugins.load(['gv', 'ops'], 'snakes.nets', 'nets')
from nets import *

###############################################################################
###############################################################################
############################### INIT ##########################################

# entities: tuple of (name, initial level, decays) triples.
#   decays[i] is the decay duration of level i: once the entity has spent
#   more than decays[i] clock ticks at level i it drops to level i-1.
#   A decay of 0 denotes unbounded decay (omega): that level never decays.
#   The number of levels of an entity is len(decays).
#entities = ( ('B',4, (0,2,2,2,3)), ('P',0, (0,0)),
#('C',0, (0,0)), ('G',0, (0,0)))
#entities = ( ('Sugar',1, (0,2)), ('Aspartame',0, (0,2)),
#('Glycemia',2, (0,2,2,2)), ('Glucagon',0, (0,2)), ('Insulin',0,(0,2,2)))

entities = (
    ('lacI', 0, (0, 2)),
    ('tetR', 0, (0, 2)),
    ('cI', 1, (0, 2)),
    ('GFP', 0, (0, 0)),
    ('gen', 0, (0,) * 6),
)

# Activities are tuples (activators, inhibitors, results, duration):
#   activators -- dict {entity: level}: the entity must be at or above level
#   inhibitors -- dict {entity: level}: the entity must be strictly below level
#   results    -- dict {entity: z}: the entity's level changes by z (clamped
#                 to its valid levels)
#   duration   -- number of clock ticks the above conditions must have held

# potential activities: may fire (transition talpha<i>) whenever enabled
#potential = ((dict([('P',0)]),dict([('P',1)]),dict([('P',1)]),0),
#             (dict([('P',1)]),dict(),dict([('P',-1)]),0),
#             (dict([('C',0)]),dict([('C',1)]),dict([('C',1)]),0),
#             (dict([('C',1)]),dict(),dict([('C',-1)]),0),
#             (dict([('G',0)]),dict([('G',5)]),dict([('G',1)]),0),
#             (dict([('G',5)]),dict(),dict([('G',-5)]),0) )

# potential = ((dict([('Sugar',1)]),dict(),dict([('Insulin',1),
#               ('Glycemia',1)]),0),
#              (dict([('Aspartame',1)]),dict(),dict([('Insulin',1)]),0),
#              (dict(),dict([('Glycemia',1)]),dict([('Glucagon',1)]),0),
#              (dict([('Glycemia',3)]),dict(),dict([('Insulin',1)]),0),
#              (dict([('Insulin',2)]),dict(),dict([('Glycemia',-1)]),0),
#              (dict([('Insulin',1),('Glycemia',3)]), dict(),
#               dict([('Glycemia',-1)]),0),
#              (dict([('Insulin',1)]),dict([('Glycemia',2)]),
#               dict([('Glycemia',-1)]),0),
#              (dict([('Glucagon',1)]),dict(),dict([('Glycemia',+1)]),0) )

potential = (
    ({}, {'lacI': 1}, {'tetR': 1}, 2),
    ({}, {'tetR': 1}, {'cI': 1}, 2),
    ({}, {'cI': 1}, {'lacI': 1}, 2),
)

# obligatory activities: applied by the clock transition tc when enabled
#obligatory = ( (dict([('P',1)]),dict(),dict([('B',1)]),1),
#               (dict([('C',1)]),dict(),dict([('B',-1)]),3),
#               (dict([('G',1)]),dict(),dict([('B',-2)]),3))

obligatory = (
    ({'lacI': 1}, {'GFP': 1}, {'GFP': 1}, 0),
    ({'lacI': 0, 'GFP': 1}, {'lacI': 1}, {'GFP': -1}, 0),
    ({'gen': 0}, {'gen': 1}, {'gen': 1}, 1),
    ({'gen': 1}, {}, {'gen': -1}, 1),
)


###############################  END ##########################################
###############################################################################
###############################################################################


###############################################################################
###############################################################################
############################ AUXILIARY FUNCTIONS ##############################

# This function computes the action of clock + decay + obligatory activities
# obligatory = set of obligatory activities
# name = entity concerned
# ls = ls of entity
# us = us of entity
# lambdas = lambdas of entity
# deltas = list of the decays duration of entity
# inputlist = tuples for all other entities
def clockt(obligatory, name, ls, us, lambdas, deltas, inputlist, D):

#    print((obligatory, name, ls, us, lambdas, deltas, inputlist, D))
#    print((name, deltas))
    l1 = ls
    u1 = us
    lambda1 = []

    # progression of time in lambda
    for i in range(0, len(lambdas)):
        lambda1.append(min(lambdas[i] + 1, D))

    # progression of time for u (only for bounded levels)
    if deltas[ls] != 0:
        u1 = us + 1

    # decay
    if deltas[ls] != 0 and us + 1 > deltas[ls]:
        l1 = max(0, ls - 1)
        u1 = 0

    # search of obligatory activities where entity name is in results
    act = []
    for alpha in range(0, len(obligatory)):
        obname = 'beta' + str(alpha)
        if name in obligatory[alpha][2] and \
           inputlist[obname] >= obligatory[alpha][3]:
            act.append(obligatory[alpha])

    # computation of the effect on entity name
    for alpha in range(0, len(act)):
        # check if the obligatory activity is enabled or not
        check = True
        activators = act[alpha][0]
#        print(activators)
        for ent in activators:
            t = inputlist[ent]
            if not(t[0] >= activators[ent] and
               t[2][activators[ent]] >= act[alpha][3]):
                check = False

        inhibitors = act[alpha][1]
        for ent in inhibitors:
            t = inputlist[ent]
            if not(t[0] < inhibitors[ent] and
               t[2][inhibitors[ent]] >= act[alpha][3]):
                check = False

        # if enabled compute the effect
        if check:
#            print(act[alpha])
            z = act[alpha][2][name]
            l1 = max(0, min(l1 + z, len(lambda1) - 1))
#            print(name,ls, z, l1)
            u1 = 0

    # update lambda with the proper dates
    temp = l1 - ls
    if temp > 0:
        for i in range(ls + 1, l1):
            lambda1[i] = 0
    if temp < 0:
        for i in range(l1 + 1, ls):
            lambda1[i] = 0
#    print((name, l1, u1, tuple(lambda1)))
    return(l1, u1, tuple(lambda1))


# This function computes the action of clock on obligatory activities places
# obligatory = set of obligatory activities
# name = obligatory activity under consideration
# w = current value
# inputlist = tuples for all other entities
def clockbetat(obligatory, name, w, inputlist, D):
#    print(name, w, inputlist, D)
#    print(obligatory)
    check = True
    activators = obligatory[name][0]
    for ent in activators:
        t = inputlist[ent]
        if not(t[0] >= activators[ent] and
           t[2][activators[ent]] >= obligatory[name][3]):
            check = False
    inhibitors = obligatory[name][1]
    for ent in inhibitors:
        t = inputlist[ent]
        if not(t[0] < inhibitors[ent] and
           t[2][inhibitors[ent]] >= obligatory[name][3]):
            check = False
        # if enabled compute the effect
    if check and w >= obligatory[name][3]:
        return(0)
    else:
        return(min(w + 1, D))


# this function computes the action on an entity of a potential activity
# name = entity under consideration
# lp, up, lambdap = its values
# R = set of results of the activity
def potentialt(name, lp, up, lambdap, R):
#    print((name, lp, up, lambdap, R))
    # entity is a result?
    if (name in R):
        lambda2 = list(lambdap)
        levelp = max(0, min(len(lambdap) - 1, lp + R[name]))
        change = levelp - lp
        if  change > 0:
            for i in range(lp + 1, levelp + 1):
                lambda2[i] = 0
        if change < 0:
            for i in range(levelp + 1, lp + 1):
                lambda2[i] = 0
        return(levelp, 0, tuple(lambda2))
    else:
        return(lp, up, lambdap)

##############################     END     ####################################
###############################################################################
###############################################################################


####################### MAIN ##################################################

# Build the Petri net encoding the activity network, then its state graph.
#
# D is the maximal duration over all potential and obligatory activities;
# every clock (activity clock places and per-level lambdas) saturates at D.
D = 0
for alpha in potential:
    D = max(D, alpha[3])

for alpha in obligatory:
    D = max(D, alpha[3])

n = PetriNet('andy')

# Names the guard/arc Expression strings below refer to when evaluated.
n.globals["obligatory"] = obligatory
n.globals["D"] = D
n.globals["clockt"] = clockt
n.globals["clockbetat"] = clockbetat
n.globals["potentialt"] = potentialt

################# Places for entities
# One place per entity, holding a single token (level, u, lambdas): the
# initial level, a decay counter of 0, and one clock per level, all 0.
for i in range(0, len(entities)):
    name = entities[i][0]
    level = entities[i][1]
    deltas = entities[i][2]
    vector = [0] * len(deltas)
    n.add_place(Place(name, [(level, 0, tuple(vector))]))

################# clock transition
# inputlist accumulates, as Python source text, the items of a dict that
# maps each 'beta<i>' to the clock variable of obligatory activity i and
# each entity name to [l<name>, u<name>, lambda<name>]. It is embedded in
# the output expressions of tc so that clockt/clockbetat see the whole
# pre-tick state of the net.
inputlist = ""
#n.globals["inputlist"] = inputlist

# tc is the global clock transition: firing it advances time by one tick.
n.add_transition(Transition('tc'))

# connect all obligatory clocks
for i in range(0, len(obligatory)):
    # name of obligatory activity i; its clock lives in place 'p' + obname
    obname = 'beta' + str(i)
    n.globals['w' + obname] = Variable('w' + obname)
    # for every obligatory activity connect corresponding place to clock
    n.add_place(Place('p' + obname, [0]))
    n.add_input('p' + obname, 'tc', Variable('w' + obname))
    inputlist += ("('" + obname + "',w" + obname + "),")

# all entities are connected: tc reads each entity token through the
# variables l<name>, u<name>, lambda<name>
for i in range(0, len(entities)):
    name = entities[i][0]
    deltas = entities[i][2]
    n.globals["deltas" + name] = deltas
    # bind the bare entity name to its own string so that expressions such
    # as "clockt(obligatory,lacI,...)" pass the name 'lacI' itself
    n.globals[name] = name

    n.globals['l' + name] = Variable('l' + name)
    n.globals['u' + name] = Variable('u' + name)
    n.globals['lambda' + name] = Variable('lambda' + name)

    n.add_input(name, 'tc', Tuple([Variable('l' + name), Variable('u' + name),
                Variable('lambda' + name)]))
    inputlist += ("('" + name + "',[l" + name + ", u" +
                    name + ", lambda" + name + "]),")

# close the dict source text, dropping the trailing comma
inputlist = "dict([" + inputlist[0:-1] + "])"

# tc puts back in each entity place the token computed by clockt ...
for i in range(0, len(entities)):
    name = entities[i][0]
    n.add_output(name, 'tc', Expression("clockt(obligatory," + name + ",l"
                 + name + ',u' + name + ',lambda' + name + ',deltas' + name
                 + ',' + inputlist + ',D)'))

# ... and in each obligatory clock place the value computed by clockbetat
for i in range(0, len(obligatory)):
    obname = 'beta' + str(i)
    # for every obligatory activity connect corresponding place to clock
    n.add_output('p' + obname, 'tc', Expression("clockbetat(obligatory,"
                 + str(i) + ",w" + obname + ',' + inputlist + ', D)'))

## potential activities
# Each potential activity i becomes a transition talpha<i> with a clock
# place ptalpha<i>. tc advances that clock (capped at D); firing the
# transition resets it to 0 and applies potentialt to every entity involved.
for i in range(0, len(potential)):
    # transition name
    trname = 'talpha' + str(i)

    # for every potential activity connect corresponding place to clock
    n.add_place(Place('p' + trname, [0]))
    n.add_input('p' + trname, 'tc', Variable('w' + trname))
    n.add_output('p' + trname, 'tc', Expression('min(D, w' + trname + '+1)'))

    activators = potential[i][0]
    inhibitors = potential[i][1]
    results = potential[i][2]
#    print(results)
    n.globals["results" + trname] = results
    duration = potential[i][3]

    # compute entities involved in the activity
    nameactivators = list(activators.keys())
    nameinhib = list(inhibitors.keys())
    nameresults = list(results.keys())
    names = []
    # collect each involved entity only once: activators, then inhibitors,
    # then results
    # NOTE(review): these loops reuse `i`, shadowing the outer loop index;
    # harmless only because the outer `i` is not read after this point.
    for i in nameactivators:
        names.append(i)
    for i in nameinhib:
        if not (i in activators):
            names.append(i)
    for i in nameresults:
        if not ((i in activators) or (i in inhibitors)):
            names.append(i)

    # compute guard of the activity
    # the activity's own clock w must have reached its duration, so the
    # activity may fire at most once every `duration` ticks
    guard = 'w>=' + str(duration)

    # activators: level at least `level`, and that level's clock has
    # reached `duration`
    for j in range(0, len(nameactivators)):
        spec = nameactivators[j]
        level = str(activators[nameactivators[j]])
        guard += ' and l' + spec + '>= ' + level + ' and lambda' \
                 + spec + '[' + level + ']>=' + str(duration)

    # inhibitors: level strictly below `level`, and that level's clock has
    # reached `duration`
    for j in range(0, len(nameinhib)):
        spec = nameinhib[j]
        level = str(inhibitors[nameinhib[j]])
        guard += ' and l' + spec + '< ' + level + ' and lambda' + spec \
                 + '[' + level + ']>=' + str(duration)

    # firing the activity resets its clock place to 0
    n.add_transition(Transition(trname, Expression(guard)))
    n.add_input('p' + trname, trname, Variable('w'))
    n.add_output('p' + trname, trname, Expression('0'))

    # arcs of the transition from and to involved entities
    for j in range(0, len(names)):
        n.add_input(names[j], trname, Tuple([Variable('l' + names[j]),
                    Variable('u' + names[j]), Variable('lambda' + names[j])]))
        n.add_output(names[j], trname, Expression("potentialt(" + names[j]
                      + ",l" + names[j] + ',u' + names[j] + ',lambda' + names[j]
                      + ', results' + trname + ')'))

# explore the reachable markings of the net
s = StateGraph(n)
s.build()


def node_attr(state, graph, attr):
    """Label a state-graph node with the entity levels of its marking.

    The label lists the level of every entity, in the order the entities
    are declared in ``entities``, joined with ':' (e.g. ``0:0:1:0:0``).
    Each entity place holds a single (level, u, lambdas) token, and its
    first field is shown.
    """
    # attr['label'] = str(state)   # alternative: show the raw state id
    marking = graph[state]
    # Derive the place names from the model configuration instead of
    # hard-coding them, so the labels stay valid when the entities change.
    attr["label"] = ":".join(str(list(marking(entity[0]))[0][0])
                             for entity in entities)


def edge_attr(trans, mode, attr):
    """Label a state-graph edge with the name of the transition it fires.

    The firing mode ``mode`` is not shown.
    """
    transition_name = trans.name
    attr["label"] = transition_name


def arc_attr(arc, attr):
    """Draw net arcs without any label text."""
    # The arc expressions built for tc embed the whole generated input dict,
    # which would presumably clutter the picture, so every label is blanked.
    no_label = ""
    attr["label"] = no_label

######## depict Petri net
n.draw("repress.png", arc_attr=arc_attr)

# To render the state graph instead (labels from node_attr / edge_attr):
#s.draw('repressgraph.ps', node_attr=node_attr, edge_attr=edge_attr,
#       engine="dot")
#
#g = s.draw(None, node_attr=node_attr, edge_attr=edge_attr, engine="dot")
#with open("repressgraph.dot", "w") as out:
#    out.write(g.dot())
#
#g.render("repressgraph-layout.dot", engine="dot")

### scenario
# Print the enabled modes of the clock transition and of every
# potential-activity transition in the current marking of the net.
# The transition names are derived from `potential` ('talpha' + index,
# exactly as the net construction names them), so the scenario follows
# the model configuration instead of assuming three potential activities.
for trname in ['tc'] + ['talpha' + str(k) for k in range(len(potential))]:
    t = n.transition(trname)
    m = t.modes()
    print(trname + ", ", m)

